import argparse
from dataclasses import dataclass
from typing import TypeVar, List, Union


@dataclass
class Arg:
    """
    Class for tracking command-line arg configurations.

    Each instance describes one ``--<name>`` option and knows how to register
    itself on an ``argparse.ArgumentParser`` via ``add_arg``.
    """

    # Option name; registered on the parser as "--" + name
    name: str
    # Value used when the option is absent from the command line (and, for
    # valued options, when the flag is given without a value)
    default: Union[int, float, bool, str, None]
    # Conversion callable for the option's value (e.g. int, float, str);
    # bool marks a flag-style option that is configured through `action`
    type: type
    # argparse action used for bool flags (e.g. "store_true"); unused otherwise
    action: Union[str, None] = None

    def add_arg(self, parser_obj: argparse.ArgumentParser) -> None:
        """
        Register this arg as ``--<name>`` on ``parser_obj``.

        Bool args are added with ``self.action`` and no ``type`` conversion.
        All other args accept an optional value converted with ``self.type``.
        If the flag is omitted, or given without a value, the result is
        ``self.default``.

        Args:
            parser_obj: Parser to add the option to.
        """
        flag = "--" + self.name
        if self.type is bool:
            parser_obj.add_argument(flag, default=self.default, action=self.action)
        else:
            # nargs="?" lets the flag appear without a value. Without `const`,
            # argparse would store None in that case; falling back to the
            # default keeps the parsed value well-typed.
            parser_obj.add_argument(
                flag,
                default=self.default,
                type=self.type,
                nargs="?",
                const=self.default,
            )


def parse(
    valid_args: List[str], argv: Union[List[str], None] = None
) -> List[Union[int, float, bool, str, None]]:
    """
    Parse the requested command-line args and return their values.

    The parse() function receives a list of strings with the desired arguments
    and returns a list containing the values of those arguments, in the same
    order as ``valid_args``.

    The parse function uses argparse, allowing users to parse arguments through
    the CLI. Only the args named in ``valid_args`` are registered on the parser.
    Passing any other ``--option`` on the command line therefore makes argparse
    print a usage error and exit (``SystemExit``).

    Args:
        valid_args: Names of the args to accept, e.g. ["batch_size", "height"].
            Every name must be one of the known args listed below. A name that
            is repeated is registered once, and its value is returned at every
            position where it appears.
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the process command line (sys.argv[1:]).

    Returns:
        One parsed value per entry in ``valid_args``.

    Raises:
        KeyError: If a name in ``valid_args`` is not a known arg.
    """
    # Every arg this helper knows how to parse, keyed by arg name
    tkml_args = {
        # General args
        # Batch size for the input to the model that will be used for benchmarking
        "batch_size": Arg("batch_size", default=1, type=int),
        # Maximum sequence length for the model's input;
        # also the input sequence length that will be used for benchmarking
        "max_seq_length": Arg("max_seq_length", default=128, type=int),
        # Maximum sequence length for the model's audio input;
        # also the input sequence length that will be used for benchmarking
        "max_audio_seq_length": Arg("max_audio_seq_length", default=25600, type=int),
        # Height of the input image that will be used for benchmarking
        "height": Arg("height", default=224, type=int),
        # Number of channels in the input image that will be used for benchmarking
        "num_channels": Arg("num_channels", default=3, type=int),
        # Width of the input image that will be used for benchmarking
        "width": Arg("width", default=224, type=int),
        # Args for Graph Neural Networks
        # k is the size of the Chebyshev filter. A Chebyshev filter is a
        # polynomial-based filter that provides a fast, localized spectral
        # response on graph-structured data (example: used in ChebConv)
        "k": Arg("k", default=2, type=int),
        # alpha represents the teleport probability used in the propagation
        # step, which determines how much information is taken from distant
        # nodes in the graph.
        # NOTE(review): a probability normally lies in [0, 1], but the default
        # here is 2.2; confirm this value is intended.
        "alpha": Arg("alpha", default=2.2, type=float),
        # out_channels is the number of output features (dimensions)
        # produced by the graph neural network layers
        "out_channels": Arg("out_channels", default=1500, type=int),
        # num_layers is the number of hidden layers (the depth) of the neural
        # network, which affects its expressiveness and complexity
        "num_layers": Arg("num_layers", default=8, type=int),
        # in_channels is the number of input features per node in the graph,
        # i.e. the input dimensionality of the network
        "in_channels": Arg("in_channels", default=1433, type=int),
        # Pretrained indicates whether pretrained weights should be used on the model
        "pretrained": Arg("pretrained", default=False, type=bool, action="store_true"),
        # Path on the filesystem to a copy of the model.
        # Useful when models have special licensing terms and we can't directly
        # provide access
        "model_path": Arg("model_path", default=None, type=str),
        # True: use an LLM's KV-cache (ie, token phase). False: disable the
        # KV-cache (ie, prefill phase).
        "kvcache": Arg("kvcache", default=False, type=bool, action="store_true"),
    }

    # Create a parser that accepts only the args received as part of
    # valid_args. dict.fromkeys drops repeated names while keeping first-seen
    # order, because argparse raises an error if the same option is
    # registered twice.
    parser = argparse.ArgumentParser()
    for arg_name in dict.fromkeys(valid_args):
        if arg_name not in tkml_args:
            raise KeyError(
                f"{arg_name} is not a valid arg. "
                f"Valid args are {', '.join(tkml_args)}."
            )
        tkml_args[arg_name].add_arg(parser)

    # Parse the command line (or argv) and return the values in valid_args order
    args = vars(parser.parse_args(argv))
    return [args[arg_name] for arg_name in valid_args]
